#!/bin/bash
source "${LMD_BASE_INSTALL_SCRIPT_DIR}/global/sys_check.sh"
#######################################
# Install triton-windows and a prebuilt flash_attn wheel into the conda
# environment, but only on Windows hosts where `nvidia-smi` is on PATH.
# Non-Windows hosts and hosts without nvidia-smi are left untouched.
# Helpers current_os_is_windows / conda_run_cmd / conda_run_pip_install are
# not defined here (presumably provided by the sourced sys_check.sh or the
# caller — confirm).
# Globals:
#   GIT_HOST (read) - base URL hosting the kingbri1/flash-attention release
#                     wheels; must be non-empty for flash_attn to install.
#   FLASH_ATTN_CP_PY_VERSION (read, optional) - CPython tag of the target
#                     conda env, e.g. 311 or 312 (default: 312).
# Arguments:
#   None
# Outputs:
#   Progress messages on stdout, errors on stderr.
# Returns:
#   0 when nothing needed doing or every install succeeded,
#   1 when GIT_HOST is missing or an install failed.
#######################################
check_flash_attn() {
    if ! current_os_is_windows; then
        return 0
    fi
    echo "windows os."

    # nvidia-smi on PATH is the only signal used for "NVIDIA driver present";
    # no CUDA toolkit check is performed.
    if ! command -v nvidia-smi &>/dev/null; then
        echo "nvidia-smi not found. skip flash_attn"
        return 0
    fi
    echo "Found NVIDIA GPU Driver. install flash_attn"

    # The wheel filename pins CUDA 12.8 / torch 2.7.0; only the CPython tag is
    # overridable so the wheel can match a conda env not on Python 3.12.
    local cp_py_version="${FLASH_ATTN_CP_PY_VERSION:-312}"
    # NOTE(review): triton-windows 3.4.x appears to be paired with PyTorch 2.8
    # in the triton-windows README, while the flash_attn wheel below targets
    # torch 2.7.0 — confirm this pair is intentional.
    local triton_spec="triton-windows==3.4.0.post21"
    local flash_attn_tag="v2.7.4.post1"
    local flash_attn_wheel="flash_attn-2.7.4.post1+cu128torch2.7.0cxx11abiFALSE-cp${cp_py_version}-cp${cp_py_version}-win_amd64.whl"
    local status=0

    if conda_run_cmd pip show triton-windows &>/dev/null; then
        echo "triton-windows is installed"
    else
        echo "start install triton-windows"
        # Keep going on failure so flash_attn is still attempted; the failure
        # is reported through the return status.
        if ! conda_run_pip_install "$triton_spec"; then
            echo "failed to install triton-windows" >&2
            status=1
        fi
    fi

    if conda_run_cmd pip show flash_attn &>/dev/null; then
        echo "flash_attn is installed"
    elif [[ -z "${GIT_HOST:-}" ]]; then
        # Without GIT_HOST the URL would collapse to a bare "/kingbri1/..."
        # path that pip cannot fetch.
        echo "GIT_HOST is not set; cannot download flash_attn wheel" >&2
        status=1
    else
        echo "start install flash_attn"
        if ! conda_run_pip_install "${GIT_HOST}/kingbri1/flash-attention/releases/download/${flash_attn_tag}/${flash_attn_wheel}"; then
            echo "failed to install flash_attn" >&2
            status=1
        fi
    fi

    return "$status"
}

# Run the check immediately whenever this script is executed or sourced.
check_flash_attn

